{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "A Tour of Oryx",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZrwVQsM9TiUw"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Probability Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "CpDUTVKYTowI",
        "colab": {}
      },
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ltPJCG6pAUoc"
      },
      "source": [
        "# A Tour of Oryx\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/probability/oryx/examples/a_tour_of_oryx\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/probability/blob/master/spinoffs/oryx/examples/notebooks/a_tour_of_oryx.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/probability/blob/master/spinoffs/oryx/examples/notebooks/a_tour_of_oryx.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/probability/spinoffs/oryx/examples/notebooks/a_tour_of_oryx.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Cvrh7Ppuwlbb"
      },
      "source": [
        "## What is Oryx?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "F_n9c7K3xdKQ"
      },
      "source": [
        "Oryx is an experimental library that extends [JAX](https://github.com/google/jax) to applications ranging from building and training complex neural networks to approximate Bayesian inference in deep generative models. Just as JAX provides `jit`, `vmap`, and `grad`, Oryx provides a set of **composable function transformations** that let you write simple code and transform it to build complexity, while staying completely interoperable with JAX.\n",
        "\n",
        "JAX can only safely transform pure, functional code (i.e. code without side effects). While pure code can be easier to write and reason about, \"impure\" code can often be more concise and more expressive.\n",
        "\n",
        "At its core, Oryx is a library that enables \"augmenting\" pure functional code to accomplish tasks like defining state or pulling out intermediate values. Its goal is to be as thin a layer on top of JAX as possible, leveraging JAX's minimalist approach to numerical computing. Oryx is conceptually divided into several \"layers\", each building on the one below it.\n",
        "\n",
        "The source code for Oryx can be found [on GitHub](https://github.com/tensorflow/probability/tree/master/spinoffs/oryx)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8cloSFmOiJqn"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "cdhNEzj6iJCc",
        "colab": {}
      },
      "source": [
        "!pip install oryx 1>/dev/null\n",
        "!pip install jaxlib --upgrade 1>/dev/null"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "Ve8yVrLbiOXv",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "sns.set(style='whitegrid')\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax import random\n",
        "from jax import vmap\n",
        "from jax import jit\n",
        "from jax import grad\n",
        "\n",
        "import oryx\n",
        "\n",
        "tfd = oryx.distributions\n",
        "\n",
        "state = oryx.core.state\n",
        "ppl = oryx.core.ppl\n",
        "\n",
        "inverse = oryx.core.inverse\n",
        "ildj = oryx.core.ildj\n",
        "plant = oryx.core.plant\n",
        "reap = oryx.core.reap\n",
        "sow = oryx.core.sow\n",
        "unzip = oryx.core.unzip\n",
        "\n",
        "nn = oryx.experimental.nn\n",
        "mcmc = oryx.experimental.mcmc\n",
        "optimizers = oryx.experimental.optimizers"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AF05PEzd8QFI"
      },
      "source": [
        "## Layer 0: Base function transformations\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8WVTh54ZBJvq"
      },
      "source": [
        "At its base, Oryx defines several new function transformations. These transformations are implemented using JAX's tracing machinery and are interoperable with existing JAX transformations like `jit`, `grad`, `vmap`, etc.\n",
        "\n",
        "### Automatic function inversion\n",
        "`oryx.core.inverse` and `oryx.core.ildj` are function transformations that can programmatically invert a function and compute its inverse log-det Jacobian (ILDJ), respectively. These transformations are useful in probabilistic modeling for computing log-probabilities using the change-of-variables formula. They are only compatible with certain types of functions, however (see [the documentation](https://tensorflow.org/probability/oryx/api_docs/python/oryx/core/interpreters/inverse) for more details)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "YxbReBYs5OpM",
        "colab": {}
      },
      "source": [
        "def f(x):\n",
        "  return jnp.exp(x) + 2.\n",
        "print(inverse(f)(4.))  # ln(2)\n",
        "print(ildj(f)(4.)) # -ln(2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-U08JAgs5w5p"
      },
      "source": [
        "### Harvest\n",
        "`oryx.core.harvest` enables tagging values in functions, along with the ability to collect (\"reap\") them and to inject (\"plant\") values in their place. We tag values using the `sow` function."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "pFJNr4SR5_vl",
        "colab": {}
      },
      "source": [
        "def f(x):\n",
        "  y = sow(x + 1., name='y', tag='intermediate')\n",
        "  return y ** 2\n",
        "print('Reap:', reap(f, tag='intermediate')(1.))  # Pulls out 'y'\n",
        "print('Plant:', plant(f, tag='intermediate')(dict(y=5.), 1.))  # Injects 5. for 'y'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ffR6Emmm5OVI"
      },
      "source": [
        "### Unzip\n",
        "`oryx.core.unzip` splits a function in two along a set of values tagged as intermediates, returning the functions `init_f` and `apply_f`. `init_f` takes in a key argument and returns the intermediates. `apply_f` takes in the intermediates along with the function's remaining inputs and returns the original function's output."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "ojFVr_ZKm0UX",
        "colab": {}
      },
      "source": [
        "def f(key, x):\n",
        "  w = sow(random.normal(key), tag='variable', name='w')\n",
        "  return w * x\n",
        "init_f, apply_f = unzip(f, tag='variable')(random.PRNGKey(0), 1.)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jUJ5isbLjGy8",
        "colab_type": "text"
      },
      "source": [
        "The `init_f` function runs `f` but only returns its variables."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "26VUK0nTjLcO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "init_f(random.PRNGKey(0))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0KWemKR2jOn6",
        "colab_type": "text"
      },
      "source": [
        "`apply_f` takes a set of variables as its first input and executes `f` with the given set of variables."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SpKFfQZqiDAR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "apply_f(dict(w=2.), 2.)  # Runs f with `w = 2`.\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Q0EtM2bj64fc"
      },
      "source": [
        "## Layer 1: Higher level transformations"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DexZ_6Ds69J4"
      },
      "source": [
        "Oryx builds on the low-level inverse, harvest, and unzip function transformations to offer several higher-level transformations for writing stateful computations and for probabilistic programming."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3zEucvAN7WJX"
      },
      "source": [
        "### Stateful functions (`core.state`)\n",
        "We're often interested in expressing stateful computations where we initialize a set of parameters and express a computation in terms of the parameters. In `oryx.core.state`, Oryx provides an `init` transformation that converts a function into one that initializes a `Module`, a container for state.\n",
        "\n",
        "`Module`s resemble PyTorch and TensorFlow `Module`s except that they are immutable."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "cmV2jLSr62Le",
        "colab": {}
      },
      "source": [
        "def make_dense(dim_out):\n",
        "  def forward(x, init_key=None):\n",
        "    w_key, b_key = random.split(init_key)\n",
        "    dim_in = x.shape[0]\n",
        "    w = state.variable(random.normal(w_key, (dim_in, dim_out)), name='w')\n",
        "    b = state.variable(random.normal(b_key, (dim_out,)), name='b')  # Use the bias key, not w_key.\n",
        "    return jnp.dot(x, w) + b\n",
        "  return forward\n",
        "\n",
        "layer = state.init(make_dense(5))(random.PRNGKey(0), jnp.zeros(2))\n",
        "print('layer:', layer)\n",
        "print('layer.w:', layer.w)\n",
        "print('layer.b:', layer.b)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dM02YafPiyVR",
        "colab_type": "text"
      },
      "source": [
        "`Module`s are registered as JAX pytrees and can be used as inputs to JAX transformed functions. Oryx provides a convenient `call` function that executes a `Module`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DRYp96JFizoU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "vmap(state.call, in_axes=(None, 0))(layer, jnp.ones((5, 2)))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "p_ZPPibD-NI4"
      },
      "source": [
        "The `state` API also enables writing stateful updates (like running averages) using the `assign` function. The resulting `Module` has an `update` function with an input signature that is the same as the `Module`'s `__call__` but creates a new copy of the `Module` with an updated state."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "fXnL3ZvD-UKx",
        "colab": {}
      },
      "source": [
        "def counter(x, init_key=None):\n",
        "  count = state.variable(0., key=init_key, name='count')\n",
        "  count = state.assign(count + 1., name='count')\n",
        "  return x + count\n",
        "layer = state.init(counter)(random.PRNGKey(0), 0.)\n",
        "print(layer.count)\n",
        "updated_layer = layer.update(0.)\n",
        "print(updated_layer.count) # Count has advanced!\n",
        "print(updated_layer.call(1.))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VO_VdtAA70EN"
      },
      "source": [
        "\n",
        "### Probabilistic programming"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-bYaYxDA-5yz"
      },
      "source": [
        "In `oryx.core.ppl`, Oryx provides a set of tools built on top of `harvest` and `inverse` that aim to make writing and transforming probabilistic programs intuitive and easy.\n",
        "\n",
        "In Oryx, a probabilistic program is a JAX function that takes a source of randomness as its first argument and returns a sample from a distribution, i.e., `f :: Key -> Sample`. In order to write these programs, Oryx wraps [TensorFlow Probability](https://www.tensorflow.org/probability) distributions and provides a simple function `random_variable` that converts a distribution into a probabilistic program."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "fh8AFQq771VJ",
        "colab": {}
      },
      "source": [
        "def sample(key):\n",
        "  return ppl.random_variable(tfd.Normal(0., 1.))(key)\n",
        "sample(random.PRNGKey(0))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JWnPjFxx_i5I"
      },
      "source": [
        "What can we do with probabilistic programs? The simplest thing would be to take a probabilistic program (i.e. a sampling function) and convert it into one that provides the log-density of a sample."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "h6U4_pAp_huX",
        "colab": {}
      },
      "source": [
        "ppl.log_prob(sample)(1.)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "51yfR5Sm2ZuD",
        "colab_type": "text"
      },
      "source": [
        "The new log-probability function is compatible with other JAX transformations like `vmap` and `grad`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "je3wggIi2Ytm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "grad(lambda s: vmap(ppl.log_prob(sample))(s).sum())(jnp.arange(10.))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wEqAS9AfAPCh"
      },
      "source": [
        "Using the `ildj` transformation, we can compute `log_prob` of programs that invertibly transform samples."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2SGe1YZ5AUP1",
        "colab": {}
      },
      "source": [
        "def sample(key):\n",
        "  x = ppl.random_variable(tfd.Normal(0., 1.))(key)\n",
        "  return jnp.exp(x / 2.) + 2.\n",
        "_, ax = plt.subplots(2)\n",
        "ax[0].hist(jit(vmap(sample))(random.split(random.PRNGKey(0), 1000)),\n",
        "    bins='auto')\n",
        "x = jnp.linspace(0, 8, 100)\n",
        "ax[1].plot(x, jnp.exp(jit(vmap(ppl.log_prob(sample)))(x)))\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "AEvnv1-__8jd"
      },
      "source": [
        "We can tag intermediate values in a probabilistic program with names and obtain joint sampling and joint log-prob functions."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "yDttqgL7_umZ",
        "colab": {}
      },
      "source": [
        "def sample(key):\n",
        "  z_key, x_key = random.split(key)\n",
        "  z = ppl.random_variable(tfd.Normal(0., 1.), name='z')(z_key)\n",
        "  x = ppl.random_variable(tfd.Normal(z, 1.), name='x')(x_key)\n",
        "  return x\n",
        "ppl.joint_sample(sample)(random.PRNGKey(0))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q45YW73E2uVK",
        "colab_type": "text"
      },
      "source": [
        "Oryx also has a `joint_log_prob` function that composes `log_prob` with `joint_sample`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FjZIhP7n2uwm",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "ppl.joint_log_prob(sample)(dict(x=0., z=0.))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OP8boCwYA50n"
      },
      "source": [
        "To learn more, see the [documentation](https://tensorflow.org/probability/oryx/api_docs/python/oryx/core/ppl/transformations)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "eglTKzL6A72r"
      },
      "source": [
        "## Layer 2: Mini-libraries"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9LdSK3XzBMuV"
      },
      "source": [
        "Building further on top of the layers that handle state and probabilistic programming, Oryx provides experimental mini-libraries tailored for specific applications like deep learning and Bayesian inference."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "iGXK3SHGBTqe"
      },
      "source": [
        "### Neural networks"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0l7OEJM2BYJu"
      },
      "source": [
        "In `oryx.experimental.nn`, Oryx provides a set of common neural network `Layer`s that fit neatly into the `state` API. These layers are built for single examples (not batches) but override batch behaviors to handle patterns like running averages in batch normalization. They also enable passing keyword arguments like `training=True/False` into modules.\n",
        "\n",
        "`Layer`s are initialized from a `Template` like `nn.Dense(200)` using `state.init`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "a6c2IjijA7Sn",
        "colab": {}
      },
      "source": [
        "layer = state.init(nn.Dense(200))(random.PRNGKey(0), jnp.zeros(50))\n",
        "print(layer, layer.params.kernel.shape, layer.params.bias.shape)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1XKSMZyuiD6v",
        "colab_type": "text"
      },
      "source": [
        "A `Layer` has a `call` method that runs its forward pass."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z0n7l3DZiNre",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "layer.call(jnp.ones(50)).shape"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "J73S0GXjCLQ2"
      },
      "source": [
        "Oryx also provides a `Serial` combinator."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "xQhmJAHVB5iN",
        "colab": {}
      },
      "source": [
        "mlp_template = nn.Serial([\n",
        "  nn.Dense(200), nn.Relu(),\n",
        "  nn.Dense(200), nn.Relu(),\n",
        "  nn.Dense(10), nn.Softmax()\n",
        "])\n",
        "# OR\n",
        "mlp_template = (\n",
        "    nn.Dense(200) >> nn.Relu()\n",
        "    >> nn.Dense(200) >> nn.Relu()\n",
        "    >> nn.Dense(10) >> nn.Softmax())\n",
        "mlp = state.init(mlp_template)(random.PRNGKey(0), jnp.ones(784))\n",
        "mlp(jnp.ones(784))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "g8h2nzyICpVd"
      },
      "source": [
        "We can interleave functions and combinators to create a flexible neural network \"meta language\"."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "NvLB8zxXChyr",
        "colab": {}
      },
      "source": [
        "def resnet(template):\n",
        "  def forward(x, init_key=None):\n",
        "    layer = state.init(template, name='layer')(init_key, x)\n",
        "    return x + layer(x)\n",
        "  return forward\n",
        "\n",
        "big_resnet_template = nn.Serial([\n",
        "  nn.Dense(50)\n",
        "  >> resnet(nn.Dense(50) >> nn.Relu())\n",
        "  >> resnet(nn.Dense(50) >> nn.Relu())\n",
        "  >> nn.Dense(10)\n",
        "])\n",
        "network = state.init(big_resnet_template)(random.PRNGKey(0), jnp.ones(784))\n",
        "network(jnp.ones(784))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7-qBbDe_D8oV"
      },
      "source": [
        "### Optimizers"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "a3c3GW1LEGKm"
      },
      "source": [
        "In `oryx.experimental.optimizers`, Oryx provides a set of first-order optimizers, built using the `state` API. Their design is based on JAX's [`optix` library](https://jax.readthedocs.io/en/latest/jax.experimental.optix.html), where optimizers maintain state about a set of gradient updates. Oryx's version manages that state using the `state` API."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "b7Gfm0d2EBC6",
        "colab": {}
      },
      "source": [
        "network_key, opt_key = random.split(random.PRNGKey(0))\n",
        "def autoencoder_loss(network, x):\n",
        "  return jnp.square(network.call(x) - x).mean()\n",
        "network = state.init(nn.Dense(200) >> nn.Relu() >> nn.Dense(2))(network_key, jnp.zeros(2))\n",
        "opt = state.init(optimizers.adam(1e-4))(opt_key, network, network)\n",
        "g = grad(autoencoder_loss)(network, jnp.zeros(2))\n",
        "\n",
        "g, opt = opt.call_and_update(network, g)\n",
        "network = optimizers.optix.apply_updates(network, g)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EGDs47TEFKXB"
      },
      "source": [
        "### Markov chain Monte Carlo"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "T7b6IdRwFP-k"
      },
      "source": [
        "In `oryx.experimental.mcmc`, Oryx provides a set of Markov chain Monte Carlo (MCMC) kernels. MCMC is an approach to approximate Bayesian inference where we draw samples from a Markov chain whose stationary distribution is the posterior distribution of interest.\n",
        "\n",
        "Oryx's MCMC library builds on both the `state` and `ppl` APIs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "wWTHfPWmGrAl",
        "colab": {}
      },
      "source": [
        "def model(key):\n",
        "  return jnp.exp(ppl.random_variable(tfd.MultivariateNormalDiag(\n",
        "      jnp.zeros(2), jnp.ones(2)))(key))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hQB7rhQ5GmN8"
      },
      "source": [
        "#### Random walk Metropolis"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "O27O2oTJE1Nu",
        "colab": {}
      },
      "source": [
        "samples = jit(mcmc.sample_chain(mcmc.metropolis(\n",
        "    ppl.log_prob(model),\n",
        "    mcmc.random_walk()), 1000))(random.PRNGKey(0), jnp.ones(2))\n",
        "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "0vTY-MiTGuQa"
      },
      "source": [
        "#### Hamiltonian Monte Carlo"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "2CWSqdO7F3Ix",
        "colab": {}
      },
      "source": [
        "samples = jit(mcmc.sample_chain(mcmc.hmc(\n",
        "    ppl.log_prob(model)), 1000))(random.PRNGKey(0), jnp.ones(2))\n",
        "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)\n",
        "plt.show()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
} 
